import torch
import einops as ein


def cdist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return pairwise Euclidean (p=2) distances between rows of ``x`` and ``y``.

    Behaves like ``torch.cdist(x, y, p=2)``: ``x`` has shape ``(..., P, M)``,
    ``y`` has shape ``(..., R, M)`` with broadcastable batch dimensions, and the
    result has shape ``(..., P, R)``.

    When ``x`` is a float16 CUDA tensor, the distances are computed by explicit
    broadcasting instead of ``torch.cdist``.
    NOTE(review): presumably a workaround for ``torch.cdist`` half-precision
    support/accuracy on CUDA -- confirm before removing. That path materialises
    a ``(..., P, R, M)`` difference tensor, so its memory grows with
    ``P * R * M``.

    Args:
        x: Tensor of shape ``(..., P, M)``.
        y: Tensor of shape ``(..., R, M)``.

    Returns:
        Tensor of shape ``(..., P, R)`` holding the L2 distance between every
        row of ``x`` and every row of ``y``.
    """
    if x.dtype == torch.float16 and x.is_cuda:
        # (..., P, 1, M) - (..., 1, R, M) broadcasts to (..., P, R, M).
        # Using negative-axis unsqueeze (rather than a fixed "b l r" pattern)
        # accepts unbatched 2-D inputs and any number of batch dimensions,
        # matching what the torch.cdist path below accepts.
        diff = x.unsqueeze(-2) - y.unsqueeze(-3)
        return torch.linalg.vector_norm(diff, ord=2, dim=-1)
    return torch.cdist(x, y, p=2)